{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "threaded-manufacturer",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "stack:  127\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "from utils import *\n",
    "from instructions import Instruction\n",
    "from settings import *\n",
    "from uncompute import *\n",
    "\n",
    "n = 1\n",
    "input_file = \"test_files/ult.btor2\"\n",
    "current_settings = get_btor2_settings(input_file)\n",
    "Instruction.all_instructions = read_file(input_file, modify_memory_sort=False, setting=current_settings)\n",
    "Instruction.with_grover = 0\n",
    "\n",
    "for i in range(1, n+1):\n",
    "    print(i)\n",
    "    Instruction.current_n = i\n",
    "    for instruction in Instruction.all_instructions.values():\n",
    "        if instruction[1] == INIT and i == 1:\n",
    "            Instruction(instruction).execute()\n",
    "        elif instruction[1] == NEXT or instruction[1] == BAD:\n",
    "            Instruction(instruction).execute()\n",
    "\n",
    "result_bad_states = Instruction.or_bad_states()\n",
    "\n",
    "print(\"stack: \", Instruction.global_stack.size)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "civilian-plaza",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{1: (QuantumRegister(8, 'q2'), [-1, -1, -1, -1, -1, -1, -1, -1])},\n",
       " {1: (QuantumRegister(8, 'q3'), [-1, -1, -1, -1, -1, -1, -1, -1])}]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Instruction.input_nids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "southern-retirement",
   "metadata": {},
   "outputs": [],
   "source": [
    "circuit_queue = get_circuit_queue(Instruction.global_stack)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "apart-outreach",
   "metadata": {},
   "outputs": [],
   "source": [
    "def are_all_controls_true(values, controls):\n",
    "    for c in controls:\n",
    "        if values[c] == 0:\n",
    "            return False\n",
    "        else:\n",
    "            assert(values[c] == 1)\n",
    "    return True\n",
    "\n",
    "def check_input(values):\n",
    "    # we only set the value of the first input the other ones are set to |0>\n",
    "    qubit_values = {}\n",
    "    for (index,qword) in enumerate(Instruction.input_nids):\n",
    "        for i in range(n+1):\n",
    "            if i in qword.states.keys():\n",
    "                temp_value = values[index]\n",
    "                for qubit in qword.states[i][0]:\n",
    "                    qubit_values[qubit] = temp_value % 2\n",
    "                    temp_value = temp_value // 2\n",
    "    element: Element = circuit_queue.pop()\n",
    "    assert(element.element_type != CHECKPOINT_TYPE)\n",
    "    while element.element_type != CHECKPOINT_TYPE:\n",
    "        for o in element.operands:\n",
    "            if o not in qubit_values.keys():\n",
    "                qubit_values[o] = 0\n",
    "\n",
    "        assert (element.target is not None)\n",
    "\n",
    "        if element.target not in qubit_values.keys():\n",
    "            qubit_values[element.target] = 0\n",
    "\n",
    "        flip_target = True\n",
    "        if element.gate_name == X:\n",
    "            assert(len(element.operands) == 0)\n",
    "\n",
    "        else:\n",
    "            assert((element.gate_name == CX and len(element.operands) ==1) or\n",
    "                   (element.gate_name == CCX and len(element.operands) == 2) or\n",
    "                    element.gate_name == MCX)\n",
    "\n",
    "            flip_target = are_all_controls_true(qubit_values, element.operands)\n",
    "        if flip_target:\n",
    "            qubit_values[element.target] = (qubit_values[element.target] + 1) % 2\n",
    "        circuit_queue.push(element)\n",
    "        element = circuit_queue.pop()\n",
    "    assert element.element_type == CHECKPOINT_TYPE\n",
    "    circuit_queue.push(element)\n",
    "    result = 0\n",
    "    for (index, qubit) in enumerate(Instruction.created_states_ids[4][1][0]):\n",
    "        result += (2**index)*qubit_values[qubit]\n",
    "    return result, qubit_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "billion-anchor",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(256):\n",
    "    for j in range(256):\n",
    "        result, values = check_input([i,j])\n",
    "        local_result = i<j\n",
    "        assert(local_result == values[Instruction.created_states_ids[5][1][0][0]])\n",
    "        assert(result == local_result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "duplicate-facial",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
